#!/usr/bin/env python
import numpy as np
import pandas as pd
import sys
import os

# R_G bin label -> representative radius of the bin (kpc): the midpoint of
# the bin edges.  Labels use an en dash (U+2013) and must match the CSV's
# ``R_G_bin`` strings exactly.  Keep the values as exact literals, because
# downstream code selects radii by float equality.
RG_MID = {
    "1.5–3.0": 2.25,   # (1.5 + 3.0) / 2
    "3.0–5.0": 4.0,    # (3.0 + 5.0) / 2
    "5.0–8.0": 6.5,    # (5.0 + 8.0) / 2
    "8.0–12.0": 10.0,  # (8.0 + 12.0) / 2
}

# Stellar-mass bin label -> central log10(M*) of that bin.  Insertion order
# is the order in which mass bins are processed downstream.
MASS_CENTERS = {
    "10.2–10.5": 10.35,  # (10.2 + 10.5) / 2
    "10.5–10.8": 10.65,  # (10.5 + 10.8) / 2
    "10.8–11.1": 10.95,  # (10.8 + 11.1) / 2
}


def load_plateau(path: str) -> pd.DataFrame:
    """Read a plateau-results CSV and derive the per-row fit inputs.

    Parameters
    ----------
    path : str
        CSV file readable by ``pd.read_csv``.

    Returns
    -------
    pd.DataFrame
        A copy of the table, restricted to ``claimable == True`` rows when a
        ``claimable`` column exists, with three added columns:

        * ``RG_mid``: the ``R_G_bin`` label mapped through ``RG_MID``
          (NaN for labels not in the mapping).
        * ``mstar_cen``: the ``Mstar_bin`` label mapped through ``MASS_CENTERS``
          (NaN for labels not in the mapping).
        * ``sigma``: half the width of ``[A_theta_CI_low, A_theta_CI_high]``,
          or 1e-3 when those columns are absent.
          - Non-finite widths are replaced by the median positive width.
          - Non-positive widths are replaced by the smallest positive width.
          - If no width is positive, every row gets 1e-3.
    """
    raw = pd.read_csv(path)

    # ``== True`` (not truthiness) so rows whose flag is missing/NaN are dropped.
    if "claimable" in raw.columns:
        table = raw[raw["claimable"] == True].copy()  # noqa: E712
    else:
        table = raw.copy()

    table["RG_mid"] = table["R_G_bin"].map(RG_MID)
    table["mstar_cen"] = table["Mstar_bin"].map(MASS_CENTERS)

    ci_cols = {"A_theta_CI_low", "A_theta_CI_high"}
    if ci_cols <= set(table.columns):
        # Half-width of a 16–84% interval approximates one standard deviation.
        width = table["A_theta_CI_high"] - table["A_theta_CI_low"]
        spread = width / 2.0
    else:
        spread = pd.Series(1e-3, index=table.index)

    # Treat infinite widths like missing ones, then repair bad entries.
    spread = spread.replace([np.inf, -np.inf], np.nan)
    if not (spread > 0).any():
        spread = pd.Series(1e-3, index=table.index)
    else:
        spread = spread.fillna(spread[spread > 0].median())
        smallest_positive = spread[spread > 0].min()
        spread[spread <= 0] = smallest_positive

    table["sigma"] = spread
    return table


def draw_dom(df: pd.DataFrame, draws: int = 50000, seed: int = 123):
    """Monte-Carlo sample Δ_out-mid for every stellar-mass bin.

    For each label in ``MASS_CENTERS`` (in dict order):

    1. Select the rows of ``df`` with that ``Mstar_bin`` whose ``RG_mid`` is
       4.0, 6.5 or 10.0.
    2. Draw ``draws`` Gaussian samples of ``A_theta`` (scale ``sigma``) at
       each of those radii.
    3. Form ``A(10.0) - 0.5 * (A(4.0) + A(6.5))``.

    Returns
    -------
    list of (label, mass centre, samples)
        ``samples`` is ``None`` when any of the three radii is missing for
        that bin.  If a radius appears in several rows of a bin, the last
        row is used.

    All samples come from a single ``np.random.default_rng(seed)`` stream,
    so output is reproducible for a fixed ``seed``.
    """
    rng = np.random.default_rng(seed)
    radii = (4.0, 6.5, 10.0)  # two "mid" radii, then the "outer" one
    results = []

    for mass_label, mass_centre in MASS_CENTERS.items():
        rows = df[df["Mstar_bin"] == mass_label]

        # radius -> (amplitude, sigma); later duplicate rows overwrite earlier ones
        by_radius = {
            row["RG_mid"]: (row["A_theta"], row["sigma"])
            for _, row in rows.iterrows()
            if row["RG_mid"] in radii
        }

        if any(r not in by_radius for r in radii):
            results.append((mass_label, mass_centre, None))
            continue

        # Sampled in the fixed order 4.0, 6.5, 10.0 so the RNG stream is stable.
        samples = {
            r: rng.normal(by_radius[r][0], by_radius[r][1], draws) for r in radii
        }
        dom = samples[10.0] - 0.5 * (samples[4.0] + samples[6.5])
        results.append((mass_label, mass_centre, dom))

    return results


def fit_hinge(dom_by_bin, mc_grid=None):
    """
    Fit Δ(m) = β0 + β1 [m - m_c]_+ with β1 >= 0
    via grid search over m_c and per-draw constrained least squares.

    For each candidate m_c the two-parameter least-squares fit is solved in
    closed form for all draws at once.  When the unconstrained slope is
    negative, the constrained optimum is β1 = 0 with β0 = mean(Δ).  The
    intercept is therefore re-fitted rather than reused from the
    unconstrained solution.

    Parameters
    ----------
    dom_by_bin : iterable of (label, m, samples)
        ``samples`` is a 1-D array of Δ draws, or ``None`` to skip that bin.
        All non-None arrays must have the same length.
    mc_grid : array-like, optional
        Candidate hinge locations.  Defaults to 10.30, 10.31, ..., 11.00.

    Returns
    -------
    best_mc, best_b1 : ndarray
        Per-draw best m_c and β1.  Both are ``array([nan])`` when fewer than
        two bins have samples.  Entries are NaN for draws containing NaN.

    NOTE(review): with only a few mass bins, the SSE can be mathematically
    flat in m_c between adjacent bin centres.  One example is m_c between the
    two highest centres.  Which grid point wins inside such a gap is then
    decided by rounding, so m_c is effectively only interval-identified.
    """
    avail = [(mass, d) for (_, mass, d) in dom_by_bin if d is not None]
    if len(avail) < 2:
        return np.array([np.nan]), np.array([np.nan])

    m = np.array([mass for mass, _ in avail], dtype=float)
    D = np.vstack([d for _, d in avail]).T  # shape: (draws, nbins)

    if mc_grid is None:
        mc_grid = np.linspace(10.3, 11.0, 71)  # step 0.01
    mc_grid = np.asarray(mc_grid, dtype=float)

    n_draws = D.shape[0]
    # With y and x both centred, the OLS slope is (yc . xc) / (xc . xc) and
    # the residual of beta0 + beta1*x (beta0 = ybar - beta1*xbar) is yc - beta1*xc.
    Dc = D - D.mean(axis=1, keepdims=True)

    best_sse = np.full(n_draws, np.inf)
    best_mc = np.full(n_draws, np.nan)
    best_b1 = np.full(n_draws, np.nan)

    for mc in mc_grid:
        x = np.maximum(0.0, m - mc)
        xc = x - x.mean()
        sxx = float(xc @ xc)

        if sxx > 0.0:
            b1 = (Dc @ xc) / sxx
        else:
            # m_c at/above every bin centre: the hinge term is identically
            # zero, so only the constant model is available.
            b1 = np.zeros(n_draws)

        # Enforce beta1 >= 0.  With beta1 = 0 the residual is exactly Dc,
        # i.e. beta0 = mean(y).
        b1 = np.maximum(b1, 0.0)

        resid = Dc - b1[:, None] * xc
        sse = np.sum(resid**2, axis=1)

        # Strict '<' keeps the earliest grid point on exact ties; NaN SSE
        # never wins, so draws with NaN samples stay NaN.
        better = sse < best_sse
        best_sse[better] = sse[better]
        best_mc[better] = mc
        best_b1[better] = b1[better]

    return best_mc, best_b1


def main():
    """Summarise the hinge-model posterior for one plateau table.

    The CSV path is the first command-line argument, falling back to
    ``outputs/lensing_plateau.csv``.  Prints:

    * P(beta1 > 0);
    * the median and 16th/84th percentiles of m_c;
    * the fraction of draws with m_c in [10.6, 10.8].

    Prints a coverage message instead when no hinge could be fitted.
    """
    extra_args = sys.argv[1:]
    path = extra_args[0] if extra_args else "outputs/lensing_plateau.csv"

    samples = draw_dom(load_plateau(path))
    mc, b1 = fit_hinge(samples)

    if np.isnan(mc).all():
        print("Insufficient coverage to fit hinge (need outer + mids in ≥ 2 mass bins).")
        return

    # NaN draws count as "not positive" / "outside the window" since
    # comparisons with NaN are False; the nan-aware summaries skip them.
    p_b1pos = float((b1 > 0).mean())
    mc_med = float(np.nanmedian(mc))
    mc_lo, mc_hi = (float(np.nanpercentile(mc, q)) for q in (16, 84))
    in_window = (mc >= 10.6) & (mc <= 10.8)
    p_mc_mw = float(in_window.mean())

    stem, _ext = os.path.splitext(os.path.basename(path))
    report = (
        f"=== {stem} ===",
        f"P(beta1>0) = {p_b1pos:.3f}",
        f"m_c median [16,84] = {mc_med:.3f}  [{mc_lo:.3f}, {mc_hi:.3f}]",
        f"P(m_c in [10.6,10.8]) = {p_mc_mw:.3f}",
    )
    for line in report:
        print(line)


if __name__ == "__main__":
    main()
